In [1]:
from glob import glob
import numpy as np
import cv2
import skimage as si
import albumentations
import plotly.express as px
import torch as t
from sergey_code.new.modules import LapCholeMultiTaskModule
from sergey_code.new.utils import rescale_to_height
In [2]:
# Demo frames to segment; sorted so the display order is deterministic.
DEMO_IMAGE_GLOB = "demo/*.png"
image_paths = sorted(glob(DEMO_IMAGE_GLOB))
if not image_paths:
    # Fail loudly instead of letting every later cell silently show nothing.
    raise FileNotFoundError(f"no demo images match {DEMO_IMAGE_GLOB!r}")
In [3]:
# Read each demo frame, pass it through the project's `rescale_to_height`
# helper (target height 128; the trailing `1` is presumably the interpolation
# order -- TODO confirm), convert to float32, preview it, and collect it.
image_arrs = []
for image_path in image_paths:
    raw_image = si.io.imread(image_path)
    rescaled_image = rescale_to_height(raw_image, 128, 1)
    image_arr = si.util.img_as_float32(rescaled_image)
    px.imshow(image_arr).show()
    image_arrs.append(image_arr)
In [4]:
# (height, width) that model inputs are padded up to; the height is also the
# target height images are rescaled to before padding.
PAD_HEIGHT, PAD_WIDTH = 128, 288
pad_to_shape = (PAD_HEIGHT, PAD_WIDTH)
In [5]:
# Zero-pad an image (and its companion validity mask) to at least
# `pad_to_shape`, keeping the original content centred.
pad_if_needed = albumentations.PadIfNeeded(
    min_height=pad_to_shape[0],
    min_width=pad_to_shape[1],
    position="center",
    border_mode=cv2.BORDER_CONSTANT,
    value=0,
    mask_value=0,
    always_apply=True,
)
padding = albumentations.Compose(
    [pad_if_needed],
    # Pad `unpadded_region_mask` with the same geometry as a mask target.
    additional_targets={"unpadded_region_mask": "mask"},
    p=1,
)
In [6]:
def image_arr2image_ten_unpadded_region_mask_arr(image_arr):
    """Turn an image array into a padded, per-channel standardised model input.

    Pipeline: keep (at most) the first three channels, rescale to the padding
    height with the project's `rescale_to_height`, convert to float32,
    zero-pad with the module-level `padding` pipeline, then standardise each
    channel to zero mean / unit standard deviation.

    Args:
        image_arr: image as an (H, W) or (H, W, C) array. A 2-D (grayscale)
            image is replicated to three channels; for C > 3 (e.g. RGBA) only
            the first three channels are used.

    Returns:
        image_ten: float32 tensor of shape (1, C, H_pad, W_pad), where C is 3
            for RGB, RGBA, or grayscale input.
        unpadded_shape: shape of the rescaled image before padding.
        unpadded_region_mask_arr: bool array (H_pad, W_pad) that is True where
            the pixel comes from the image and False where it is padding.
    """
    if image_arr.ndim == 2:
        # Grayscale: promote to three channels so that `[..., :3]` below
        # selects channels rather than slicing off image columns.
        image_arr = np.repeat(image_arr[..., np.newaxis], 3, axis=-1)
    image_arr = image_arr[..., :3]
    image_arr = rescale_to_height(image_arr, pad_to_shape[0], 1)
    unpadded_shape = image_arr.shape
    image_arr = si.util.img_as_float32(image_arr)
    # All-ones mask padded alongside the image: after padding, ones mark the
    # real image pixels and zeros mark the padded border.
    unpadded_region_mask_arr = np.ones(image_arr.shape[:2], dtype=np.float32)
    padded = padding(image=image_arr, unpadded_region_mask=unpadded_region_mask_arr)
    image_arr = padded["image"]
    unpadded_region_mask_arr = padded["unpadded_region_mask"] > 0.5
    # NOTE(review): statistics are taken over the padded frame, zero border
    # included; presumably this matches the training preprocessing -- confirm.
    channel_mean = image_arr.mean(axis=(0, 1), keepdims=True)
    channel_std = image_arr.std(axis=(0, 1), keepdims=True)
    # Non-in-place arithmetic (no aliasing of arrays returned by helpers), and
    # a floor on the std so a constant channel (e.g. an all-black frame) does
    # not divide by zero and feed NaN/inf into the model.
    image_arr = (image_arr - channel_mean) / np.maximum(channel_std, 1e-8)
    image_ten = t.from_numpy(image_arr).moveaxis(-1, 0).unsqueeze(0)
    return image_ten, unpadded_shape, unpadded_region_mask_arr
In [7]:
# Restore the trained multi-task module. `map_location="cpu"` pins all weights
# to the CPU regardless of the device the checkpoint was saved from, so the
# CPU tensors built in the inference cell always match the model's device.
# `strict=False` tolerates checkpoint entries the model lacks (the load
# warning lists `class_weights_dangerous_safe` and `loss_funs.0.weight`).
# NOTE(review): `strict=False` would also hide *missing* model keys -- confirm
# the warning above is the only mismatch.
CHECKPOINT_PATH = "epoch=348_val_mean_mean_f1=0.82698.ckpt"
module = LapCholeMultiTaskModule.load_from_checkpoint(
    CHECKPOINT_PATH,
    map_location="cpu",
    strict=False,
)
# Inference behaviour for layers such as dropout / batch-norm.
module.eval();
/home/user_libvirt/.miniforge3/envs/main/lib/python3.11/site-packages/lightning/pytorch/core/saving.py:177: UserWarning: Found keys that are not in the model state dict but in the checkpoint: ['class_weights_dangerous_safe', 'loss_funs.0.weight']
In [8]:
# Segment each demo image and show the class probabilities blended over it
# (70 % image, 30 % probabilities).
with t.inference_mode():
    for image_arr in image_arrs:
        image_ten, unpadded_shape, unpadded_region_mask_arr = (
            image_arr2image_ten_unpadded_region_mask_arr(image_arr)
        )
        # Call the module (not `.forward`) so registered hooks run, and feed
        # the input on whatever device the weights live on. Index 0 takes the
        # first of the multi-task outputs -- presumably (batch, classes, H, W)
        # logits; TODO confirm.
        logits = module.model(image_ten.to(module.device))[0]
        probs = t.nn.functional.softmax(logits, dim=1)
        probs = probs[0].moveaxis(0, -1).cpu().numpy()
        # Drop the padded border. The unpadded region is a contiguous
        # rectangle, so the row-major boolean selection reshapes back to it.
        probs = probs[unpadded_region_mask_arr].reshape(
            unpadded_shape[0], unpadded_shape[1], -1
        )
        # The blend treats the class probabilities as RGB, which assumes a
        # 3-class head; drop any alpha channel from the PNG so shapes match.
        px.imshow(image_arr[..., :3] * 0.7 + probs * 0.3).show()